package com.example.websocket;

import com.example.config.RateLimitService;
import com.example.util.EnhancedJwtTokenUtil;
import com.example.util.SecurityUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.util.UriComponentsBuilder;

import java.io.UnsupportedEncodingException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.URI;
import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;
import java.util.Map;

/**
 * WebSocket handshake interceptor (hardened version).
 *
 * <p>Authenticates and security-checks a client before its WebSocket connection is established.
 * {@link #beforeHandshake} rejects the handshake unless all of the following hold:
 * <ol>
 *   <li>the client IP passes the per-IP rate limit (otherwise {@code 429});</li>
 *   <li>a non-empty, correctly percent-encoded {@code token} query parameter is present
 *       (otherwise {@code 401});</li>
 *   <li>{@code jwtTokenUtil.validateToken} accepts the token (otherwise {@code 401});</li>
 *   <li>a non-blank username can be read from the token (otherwise {@code 401});</li>
 *   <li>the user passes the per-user rate limit (otherwise {@code 429}).</li>
 * </ol>
 * Any unexpected exception during these checks yields {@code 500}. On success the user id,
 * user type (when present), token, client IP and connect time are stored in the handshake
 * attributes, which become the WebSocket session attributes.
 */
@Slf4j
@Component
public class WebSocketInterceptor implements HandshakeInterceptor {

    /** Prefix of the query parameter that carries the token. */
    private static final String TOKEN_PREFIX = "token=";

    /** Placeholder returned when no client IP can be determined. */
    private static final String UNKNOWN = "unknown";

    /**
     * Per-IP handshake limit: at most {@code IP_LIMIT_MAX} per {@code IP_LIMIT_WINDOW}.
     * NOTE(review): the window unit is whatever RateLimitService.isAllowed expects (presumably seconds).
     */
    private static final int IP_LIMIT_MAX = 10;
    private static final int IP_LIMIT_WINDOW = 60;

    /**
     * Per-user handshake limit: at most {@code USER_LIMIT_MAX} per {@code USER_LIMIT_WINDOW}.
     * NOTE(review): this presumably counts handshake attempts per window rather than concurrently
     * open connections; confirm against RateLimitService.
     */
    private static final int USER_LIMIT_MAX = 5;
    private static final int USER_LIMIT_WINDOW = 3600;

    /** Headers consulted, in priority order, for the originating client IP. */
    private static final String[] CLIENT_IP_HEADERS = {
            "X-Forwarded-For",
            "Proxy-Client-IP",
            "WL-Proxy-Client-IP",
            "HTTP_CLIENT_IP",
            "HTTP_X_FORWARDED_FOR"
    };

    @Autowired
    private EnhancedJwtTokenUtil jwtTokenUtil;

    // NOTE(review): not referenced anywhere in this class; kept so bean wiring is unchanged.
    @Autowired
    private SecurityUtils securityUtils;

    @Autowired
    private RateLimitService rateLimitService;

    /**
     * Validates the handshake request and, on success, populates the session attributes.
     *
     * @param request    the upgrade request; the token is read from its {@code token} query parameter
     * @param response   the handshake response; its status is set whenever the handshake is rejected
     * @param wsHandler  the target handler (unused)
     * @param attributes receives {@code userId}, {@code userType} (only when non-null),
     *                   {@code token}, {@code clientIp} and {@code connectTime} (epoch millis)
     * @return {@code true} to proceed with the handshake, {@code false} to reject it
     */
    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
                                   WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
        try {
            String clientIp = getClientIp(request);

            // Per-IP handshake rate limit.
            if (!rateLimitService.isAllowed("websocket:" + clientIp, IP_LIMIT_MAX, IP_LIMIT_WINDOW)) {
                log.warn("WebSocket连接被限流: ip={}", clientIp);
                return reject(response, org.springframework.http.HttpStatus.TOO_MANY_REQUESTS);
            }

            String token;
            try {
                token = extractToken(request.getURI());
            } catch (IllegalArgumentException e) {
                // A malformed percent-escape is a client error, so answer 401 rather than 500.
                log.warn("WebSocket握手失败: token参数编码非法, ip={}", clientIp);
                return reject(response, org.springframework.http.HttpStatus.UNAUTHORIZED);
            }

            if (token == null || token.isEmpty()) {
                log.warn("WebSocket握手失败: 缺少token参数, ip={}", clientIp);
                return reject(response, org.springframework.http.HttpStatus.UNAUTHORIZED);
            }

            if (!jwtTokenUtil.validateToken(token)) {
                log.warn("WebSocket握手失败: token无效, ip={}", clientIp);
                return reject(response, org.springframework.http.HttpStatus.UNAUTHORIZED);
            }

            String username = jwtTokenUtil.getUsernameFromToken(token);
            String userType = jwtTokenUtil.getUserTypeFromToken(token);

            // A blank username is as unusable as a missing one: it would become the session's userId.
            if (username == null || username.trim().isEmpty()) {
                log.warn("WebSocket握手失败: 无法从token获取用户名, ip={}", clientIp);
                return reject(response, org.springframework.http.HttpStatus.UNAUTHORIZED);
            }

            // Per-user handshake rate limit.
            if (!rateLimitService.isAllowed("websocket_user:" + username, USER_LIMIT_MAX, USER_LIMIT_WINDOW)) {
                log.warn("WebSocket连接被拒绝: 用户连接数超限, user={}", username);
                return reject(response, org.springframework.http.HttpStatus.TOO_MANY_REQUESTS);
            }

            attributes.put("userId", username);
            // Spring copies handshake attributes into a ConcurrentHashMap-backed session attribute
            // map, which rejects null values, so a null user type is simply omitted
            // (attributes.get("userType") still yields null for readers).
            if (userType != null) {
                attributes.put("userType", userType);
            }
            // NOTE(review): this keeps the raw bearer token in session state; make sure session
            // attributes are never logged or echoed to clients.
            attributes.put("token", token);
            attributes.put("clientIp", clientIp);
            attributes.put("connectTime", System.currentTimeMillis());

            log.info("WebSocket握手成功: userId={}, userType={}, ip={}", username, userType, clientIp);
            return true;

        } catch (Exception e) {
            log.error("WebSocket握手异常: {}", e.getMessage(), e);
            return reject(response, org.springframework.http.HttpStatus.INTERNAL_SERVER_ERROR);
        }
    }

    /**
     * Logs any exception reported after the handshake has completed; otherwise does nothing.
     */
    @Override
    public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response,
                               WebSocketHandler wsHandler, Exception exception) {
        if (exception != null) {
            log.error("WebSocket握手后异常: {}", exception.getMessage(), exception);
        }
    }

    /**
     * Sets {@code status} on the response and returns {@code false}, so that rejection paths can
     * be written as {@code return reject(...)}.
     */
    private static boolean reject(ServerHttpResponse response, org.springframework.http.HttpStatus status) {
        response.setStatusCode(status);
        return false;
    }

    /**
     * Returns the decoded value of the first {@code token} query parameter, or {@code null} if none.
     *
     * <p>Parameters are split on the <em>raw</em> (still percent-encoded) query, and the value is
     * decoded exactly once. {@link URI#getQuery()} would already decode the query, so an encoded
     * {@code &} would split parameters and decoding again would corrupt escaped characters.
     *
     * @throws IllegalArgumentException     if the value contains a malformed percent-escape
     * @throws UnsupportedEncodingException never in practice (UTF-8 is always supported)
     */
    private static String extractToken(URI uri) throws UnsupportedEncodingException {
        String rawQuery = uri.getRawQuery();
        if (rawQuery == null) {
            return null;
        }
        for (String param : rawQuery.split("&")) {
            if (param.startsWith(TOKEN_PREFIX)) {
                return URLDecoder.decode(param.substring(TOKEN_PREFIX.length()), StandardCharsets.UTF_8.name());
            }
        }
        return null;
    }

    /**
     * Resolves the client IP address.
     *
     * <p>Checks the headers in {@code CLIENT_IP_HEADERS} in order and uses the first value that is
     * non-blank and not {@code "unknown"}. If none qualifies, it falls back to the socket's remote
     * address. If the chosen value lists several comma-separated addresses, only the first is kept.
     *
     * <p>NOTE(security): these headers are client-controlled unless a trusted reverse proxy
     * overwrites them. A client reaching this server directly can send a different
     * {@code X-Forwarded-For} on each attempt and thereby sidestep the per-IP rate limit.
     * Consider honouring these headers only when the remote address is a known proxy.
     *
     * @return the resolved IP, or {@code "unknown"} if none could be determined
     */
    private String getClientIp(ServerHttpRequest request) {
        String ip = null;
        for (String header : CLIENT_IP_HEADERS) {
            String candidate = request.getHeaders().getFirst(header);
            if (isUsableIp(candidate)) {
                ip = candidate;
                break;
            }
        }
        if (ip == null) {
            ip = getRemoteIp(request);
        }

        // Keep only the first of several comma-separated addresses. Use indexOf/substring rather
        // than split(","): ",".split(",") is an empty array, so split(",")[0] would throw on a
        // header that consists only of commas.
        int comma = ip.indexOf(',');
        if (comma >= 0) {
            ip = ip.substring(0, comma);
        }
        ip = ip.trim();
        return ip.isEmpty() ? UNKNOWN : ip;
    }

    /** Whether a header value can be used as an IP, i.e. it is non-blank and not {@code "unknown"}. */
    private static boolean isUsableIp(String ip) {
        return ip != null && !ip.trim().isEmpty() && !UNKNOWN.equalsIgnoreCase(ip.trim());
    }

    /**
     * Returns the textual remote socket address, or {@code "unknown"} if the request has none.
     * For an unresolved address ({@code getAddress()} is null) the host string is returned instead.
     */
    private static String getRemoteIp(ServerHttpRequest request) {
        InetSocketAddress remote = request.getRemoteAddress();
        if (remote == null) {
            return UNKNOWN;
        }
        InetAddress address = remote.getAddress();
        return address != null ? address.getHostAddress() : remote.getHostString();
    }
}
